-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
General plugin mechanism #45355
General plugin mechanism #45355
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
✅ This PR's description meets the template requirements! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add unit testing
plugindata.data = &value; | ||
} else { | ||
CHECK(false) << "not incompleted"; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not support float, and other dtype ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已补充完整
plugindatas.push_back(plugindata); | ||
} | ||
|
||
nvinfer1::PluginFieldCollection pluginFC{(int32_t)plugindatas.size(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
“pluginFC” should lowercase with underscore, see https://google.github.io/styleguide/cppguide.html#Variable_Names
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
} | ||
|
||
auto creator = | ||
GetPluginRegistry()->getPluginCreator(op_desc.Type().c_str(), "1"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add comments for why setting plugin version to “1”
auto *var = block_desc.FindVar(arg_name); | ||
PADDLE_ENFORCE_NOT_NULL( | ||
var, | ||
platform::errors::NotFound("no variable called %s in block.", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no variable called -> There is no variable called
class OpDesc; | ||
} // namespace proto | ||
} // namespace framework | ||
} // namespace paddle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove line 21 to line 28
const nvinfer1::DimsExprs* inputs, | ||
int nb_inputs, | ||
nvinfer1::IExprBuilder& expr_builder, // NOLINT | ||
const framework::OpDesc& op_desc_); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
op_desc_ -> op_desc
return false; | ||
} | ||
return true; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
GenericPlugin对op的输入、输出类型也有限制,这里是否要增加判断
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
根据 op 做具体判断,后续继续添加通用plugin支持的op时,根据op的具体情况做判断
auto* attr_ptr = attr_reader.GetAttr(attr_name); | ||
switch (attr_defs[k].type_index) { | ||
case phi::AttributeType::SCALAR: | ||
if (attr_ptr) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里 if(attr_prt) 判断是否需要提前到49行之前,下面每个case都加了判断
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
} | ||
|
||
template <typename T> | ||
inline std::string vectorToStr(const std::vector<T>& dims) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
vectorToStr -> VectorToStr
outputs_data_type_ = outputs_data_type; | ||
} | ||
|
||
GenericPlugin::GenericPlugin(void const* serialData, size_t serialLength) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
serialData -> serial_data , serialLength -> serial_len
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
if ((desc.HasAttr("namescope") && | ||
PADDLE_GET_CONST(std::string, desc.GetAttr("op_namescope")) == | ||
"/skip_quant_2/") || | ||
desc.HasAttr("skip_quant")) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里 "skip_quant_2" 太hard code了吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这部分代码是原有的代码。这里我把这些if判断从 op_conveter.h 文件移动到了 opteller中。所以git 显示是我改的。实际上这里逻辑还是原来的逻辑
// only consider dynamic_shape mode | ||
if (!with_dynamic_shape) { | ||
return false; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这为啥只考虑 dynamic_shape mode?
break; | ||
|
||
default: | ||
CHECK(false) << "no OpConverter for optype " << op_desc.Type(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
输出log信息需要语法正确
op_desc_ = std::move(framework::OpDesc(proto_op_desc_, nullptr)); | ||
// op_meta_data_ | ||
proto_op_desc_.SerializeToString(&op_meta_data_); | ||
// inputs_data_type_ and outputs_data_type_ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这些 // 注释看起来无意义吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
op_meta_data_ = std::move(op_meta_data); | ||
// proto_op_desc_ | ||
proto_op_desc_.ParseFromString(op_meta_data_); | ||
// op_desc_ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
yield self.create_inference_config(), (0, 4), 1e-5 | ||
self.trt_param.precision = paddle_infer.PrecisionType.Half | ||
yield self.create_inference_config(), (0, 4), 1e-5 | ||
yield self.create_inference_config(), (1, 3), 1e-5 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里去掉Half测试的原因是?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已经改回去了
for (auto &attr_name : op_attrs_names) { | ||
nvinfer1::PluginField plugindata; | ||
plugindata.name = attr_name.c_str(); | ||
if (op_desc.GetAttrType(attr_name) == framework::proto::AttrType::INT) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
貌似可以用宏折叠起来。其他switch case 同
namespace inference { | ||
namespace tensorrt { | ||
|
||
nvinfer1::DimsExprs GatherNdInferMeta( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
所有的OpInferMeta都要写在这个文件中吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
嗯 集中放在这个文件中
free(dense_tensor_inputs); | ||
free(dense_tensor_outputs); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里用free可以释放掉吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是可以的, 只需要释放申请的一个vector对象,vector对象里面的元素,vector在释放的过程中会去释放
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
析构函数不会被正确调用,有内存泄露风险
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已改为delete
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
a3a20a4
PR types
New features
PR changes
Others
Describe
通用 Plugin 以及 自定义Op Plugin 加载机制
方案要点:
当前paddle trt 工作流程图如下:
这部分的主要功能就是利用 xxOpConverter 对象将Paddle Op 转换为 TRT Layer。新方案新增两个 converter(creater),能够使以下算子进入TRT。
第一类Op:该Op被通用 Plugin 所支持;
第二类Op:用户自定义Op
新增两个 converter:
generic_and_custom_plugin_creater.cc
generic_plugin_creater :该 converter 为一些 Op 创建通用Plugin;
custom_plugin_creater : 该 converter 为自定义 Op 创建用户自定义的Plugin;
调整 了OpTell;
对于 paddle 已有的大量 op converter, 加上新增的两个 converter,已有三种converter。对于一个 paddle op 需要使用哪种 converter, 由 OpTeller 判断。
之前 OpTeller 的功能是判断一个 Paddle Op , TRT 支持与否, 其中大量的 if 判断放在了
op_converter.h
文件中,这不太合理,op_converter.h 只负责根据Op名字,找到对应的 converter, 而能不能转的 if 等边界条件判断,应当移入op_tell.cc
中。由于现在存在三种 converter, 因此 OpTeller 不仅要告知一个 Paddle Op , TRT 支持与否,还需要告知使用哪一种 converter。
Default 表示使用框架内部已有的 xx_op_converter。
测试
1. 测试 custom_plugin_creater
通过在
dynamic_shape_infermeta.cc
文件中增加 符号化的 shape 推导函数,使得对应的 Op 被通用 plugin 所支持。在已有的auto_scan的单测中便能够测试 通用plugin对该 Op的支持情况。
如在
dynamic_shape_infermeta.cc
中增加了gather_nd
op 的符号化shape推导函数, 在test_trt_convert_gather_nd.py
单测中就会使用 通用 plugin。2. 测试 custom_plugin_creater
test_custom_plugin_creater.cc
test_custom_op_plugin.cc
测试 custom_plugin_creater 的功能:能否加载用户自定义 Op 的 Plugin。主要测试 custom_plugin_creater 能否正确加载到用户自定义Op, 获取Op的各种属性信息,以及从 IPluinRegistry 中得到正确的 plugin_creator, 将自定义Op的属性正确传给
plugin_creator,然后成功创建自定义Op的plugin。
规定:
用户自定义Op的 静态shape plugin,以名称 “_paddle_trt_plugin”结尾; 动态shape plugin 以 “_paddle_trt_dynamic_plugin”结尾。
在
test_custom_op_plugin.cc
文件下为 custom_op 定义了一个静态shape plugin 和 动态 shape plugin。这两个plugin 直接继承 public nvinfer1::IPluginV2 ,public nvinfer1::IPluginV2DynamicExt,
两个plugin 都没有具体的计算逻辑。主要实现的接口为:
该 api 负责接收自定义Op的属性信息,然后创建 plugin 对象。